{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import time\n",
    "import imageio\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from IPython.display import clear_output\n",
    "\n",
    "from lib.CAModel import CAModel\n",
    "from lib.utils_vis import SamplePool, to_alpha, to_rgb, get_living_mask, make_seed, make_circle_masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_emoji(index, path=\"data/emoji.png\", size=40):\n",
    "    \"\"\"Cut one emoji tile out of a horizontal sprite sheet.\n",
    "\n",
    "    The image at `path` is read with imageio and the columns\n",
    "    index*size up to (but excluding) (index+1)*size are kept, across all rows.\n",
    "\n",
    "    Args:\n",
    "        index: Tile position along the sheet's width (0 is the leftmost tile).\n",
    "        path: Location of the sprite-sheet image.\n",
    "        size: Tile width in pixels; the default of 40 is the value that\n",
    "            used to be hard-coded here.\n",
    "\n",
    "    Returns:\n",
    "        A float32 numpy array holding the tile, with values divided by 255.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the requested column range selects no pixels\n",
    "            (for example an index past the end of the sheet).\n",
    "    \"\"\"\n",
    "    im = imageio.imread(path)\n",
    "    start, stop = index * size, (index + 1) * size\n",
    "    # np.array(..., dtype=...) always copies, so the in-place scaling below\n",
    "    # never touches the image as loaded.\n",
    "    emoji = np.array(im[:, start:stop], dtype=np.float32)\n",
    "    if emoji.shape[1] == 0:\n",
    "        # An empty tile is never a usable target; fail here with the cause spelled out.\n",
    "        raise ValueError(\n",
    "            \"emoji index {} (columns {}:{}) selects nothing from the {}-pixel-wide sheet {!r}\".format(\n",
    "                index, start, stop, im.shape[1], path))\n",
    "    emoji /= 255.0\n",
    "    return emoji\n",
    "\n",
    "def visualize_batch(x0, x):\n",
    "    \"\"\"Draw a two-row grid comparing a batch before and after an update.\n",
    "\n",
    "    Both inputs are converted with to_rgb. The top row shows the samples of\n",
    "    x0, the bottom row the samples of x; column i holds sample i. The number\n",
    "    of columns is taken from x0.shape[0].\n",
    "    \"\"\"\n",
    "    before = to_rgb(x0)\n",
    "    after = to_rgb(x)\n",
    "    n_cols = x0.shape[0]\n",
    "    print('batch (before/after):')\n",
    "    plt.figure(figsize=[15,5])\n",
    "    for row, images in enumerate((before, after)):\n",
    "        for col in range(n_cols):\n",
    "            # Subplot indices are 1-based and run row by row.\n",
    "            plt.subplot(2, n_cols, row*n_cols + col + 1)\n",
    "            plt.imshow(images[col])\n",
    "            plt.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "def plot_loss(loss_log):\n",
    "    \"\"\"Plot log10 of every recorded loss value as faint dots, one per entry.\"\"\"\n",
    "    plt.figure(figsize=(10, 4))\n",
    "    plt.title('Loss history (log10)')\n",
    "    log_losses = np.log10(loss_log)\n",
    "    # Low alpha keeps dense regions readable when many steps overlap.\n",
    "    plt.plot(log_losses, '.', alpha=0.1)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first GPU when CUDA is usable, otherwise fall back to the CPU so\n",
    "# the notebook still runs on machines without a GPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Checkpoint file: loaded before training and overwritten while training.\n",
    "model_path = \"models/remaster_1.pth\"\n",
    "\n",
    "CHANNEL_N = 16        # Number of CA state channels\n",
    "TARGET_PADDING = 16   # Number of pixels used to pad the target image border\n",
    "TARGET_SIZE = 40      # Not referenced elsewhere in this notebook; load_emoji slices 40-pixel-wide tiles\n",
    "\n",
    "lr = 2e-3             # Initial Adam learning rate\n",
    "lr_gamma = 0.9999     # ExponentialLR decay factor, applied on every scheduler step\n",
    "betas = (0.5, 0.5)    # Adam beta coefficients\n",
    "n_epoch = 80000       # The training loop runs n_epoch + 1 iterations\n",
    "\n",
    "BATCH_SIZE = 8        # Samples drawn per training iteration\n",
    "POOL_SIZE = 1024      # Number of seed copies the sample pool starts with\n",
    "CELL_FIRE_RATE = 0.5  # Passed to CAModel\n",
    "\n",
    "TARGET_EMOJI = 0 #@param \"🦎\"\n",
    "\n",
    "# The experiment type selects the pool / damage settings below by index.\n",
    "EXPERIMENT_TYPE = \"Regenerating\"\n",
    "EXPERIMENT_MAP = {\"Growing\":0, \"Persistent\":1, \"Regenerating\":2}\n",
    "EXPERIMENT_N = EXPERIMENT_MAP[EXPERIMENT_TYPE]\n",
    "\n",
    "USE_PATTERN_POOL = [0, 1, 1][EXPERIMENT_N]  # Non-zero: train on states sampled from the pool\n",
    "DAMAGE_N = [0, 0, 3][EXPERIMENT_N]  # Number of patterns to damage in a batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the configured emoji tile and preview it in a small figure.\n",
    "target_img = load_emoji(TARGET_EMOJI)\n",
    "plt.figure(figsize=(4, 4))\n",
    "target_rgb = to_rgb(target_img)\n",
    "plt.imshow(target_rgb)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pad height and width by TARGET_PADDING on each side (np.pad's default\n",
    "# constant mode fills with zeros); the channel axis is left untouched.\n",
    "p = TARGET_PADDING\n",
    "pad_target = np.pad(target_img, [(p, p), (p, p), (0, 0)])\n",
    "h, w = pad_target.shape[:2]\n",
    "# Add a leading axis of length 1 and move the target to the training device.\n",
    "pad_target = np.expand_dims(pad_target, axis=0)\n",
    "pad_target = torch.from_numpy(pad_target.astype(np.float32)).to(device)\n",
    "\n",
    "# The pool starts as POOL_SIZE copies of the seed state.\n",
    "seed = make_seed((h, w), CHANNEL_N)\n",
    "pool = SamplePool(x=np.repeat(seed[None, ...], POOL_SIZE, 0))\n",
    "# The training loop draws its own `batch` whenever it uses the pool.\n",
    "batch = pool.sample(BATCH_SIZE).x\n",
    "\n",
    "ca = CAModel(CHANNEL_N, CELL_FIRE_RATE, device).to(device)\n",
    "# map_location remaps the stored tensors onto `device`, so a checkpoint saved\n",
    "# on a GPU can also be restored on a CPU-only machine or a different GPU.\n",
    "# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.\n",
    "ca.load_state_dict(torch.load(model_path, map_location=device))\n",
    "\n",
    "optimizer = optim.Adam(ca.parameters(), lr=lr, betas=betas)\n",
    "scheduler = optim.lr_scheduler.ExponentialLR(optimizer, lr_gamma)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss value of every training iteration so far; the loop below appends one\n",
    "# entry per iteration and uses the list's length as its step counter.\n",
    "loss_log = []\n",
    "\n",
    "def train(x, target, steps, optimizer, scheduler):\n",
    "    \"\"\"Run the CA forward and apply one optimisation step on the 4-channel MSE.\n",
    "\n",
    "    Uses the notebook-level model `ca`.\n",
    "\n",
    "    Args:\n",
    "        x: Batch of CA states passed to `ca`.\n",
    "        target: Target compared against the first four channels of the\n",
    "            output's last axis; it is broadcast against the prediction, so a\n",
    "            leading axis of length 1 is accepted.\n",
    "        steps: Forwarded to the model as `ca(x, steps=steps)`.\n",
    "        optimizer: Optimizer whose zero_grad() and step() are called once.\n",
    "        scheduler: Learning-rate scheduler, stepped once per call.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (x, loss): the model output (not detached) and the scalar\n",
    "        loss tensor.\n",
    "    \"\"\"\n",
    "    x = ca(x, steps=steps)\n",
    "    # F.mse_loss warns when input and target sizes differ and then broadcasts\n",
    "    # them itself; broadcasting up front gives the same loss without the warning.\n",
    "    pred, expanded_target = torch.broadcast_tensors(x[:, :, :, :4], target)\n",
    "    loss = F.mse_loss(pred, expanded_target)\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    return x, loss\n",
    "\n",
    "def loss_f(x, target):\n",
    "    \"\"\"Mean squared error between x's first four last-axis channels and target.\n",
    "\n",
    "    The squared difference is averaged over the last three axes, so any\n",
    "    leading axes are kept in the result.\n",
    "    \"\"\"\n",
    "    squared_error = (x[..., :4] - target).pow(2)\n",
    "    return squared_error.mean(dim=(-3, -2, -1))\n",
    "\n",
    "# Main training loop: n_epoch + 1 iterations of sample -> train -> write back.\n",
    "for i in range(n_epoch+1):\n",
    "    if USE_PATTERN_POOL:\n",
    "        batch = pool.sample(BATCH_SIZE)\n",
    "        x0 = torch.from_numpy(batch.x.astype(np.float32)).to(device)\n",
    "        # Order the sampled states by per-sample loss, highest loss first.\n",
    "        loss_rank = loss_f(x0, pad_target).detach().cpu().numpy().argsort()[::-1]\n",
    "        # Index-array selection; presumably a copy (batch.x looks like a numpy array), so the edits below would not touch batch.x.\n",
    "        x0 = batch.x[loss_rank]\n",
    "        # Replace the highest-loss state with a fresh seed.\n",
    "        x0[:1] = seed\n",
    "        if DAMAGE_N:\n",
    "            # Multiply the last DAMAGE_N (lowest-loss) states by 1 - mask.\n",
    "            # NOTE(review): presumably make_circle_masks returns DAMAGE_N masks of size (h, w) - confirm in lib.utils_vis.\n",
    "            damage = 1.0-make_circle_masks(DAMAGE_N, h, w)[..., None]\n",
    "            x0[-DAMAGE_N:] *= damage\n",
    "    else:\n",
    "        # Without the pool every iteration starts from BATCH_SIZE copies of the seed.\n",
    "        x0 = np.repeat(seed[None, ...], BATCH_SIZE, 0)\n",
    "    x0 = torch.from_numpy(x0.astype(np.float32)).to(device)\n",
    "\n",
    "    # The number of CA steps is drawn uniformly from 64..95 for every iteration.\n",
    "    x, loss = train(x0, pad_target, np.random.randint(64,96), optimizer, scheduler)\n",
    "    \n",
    "    if USE_PATTERN_POOL:\n",
    "        # Overwrite the sampled batch with the evolved states.\n",
    "        # NOTE(review): commit() presumably writes the batch back into the pool - confirm in lib.utils_vis.\n",
    "        batch.x[:] = x.detach().cpu().numpy()\n",
    "        batch.commit()\n",
    "\n",
    "    # step_i is the number of losses recorded before this iteration.\n",
    "    step_i = len(loss_log)\n",
    "    loss_log.append(loss.item())\n",
    "    \n",
    "    # Every 100 steps: refresh the cell output, show the batch and the loss curve, and overwrite the checkpoint.\n",
    "    if step_i%100 == 0:\n",
    "        clear_output()\n",
    "        print(step_i, \"loss =\", loss.item())\n",
    "        visualize_batch(x0.detach().cpu().numpy(), x.detach().cpu().numpy())\n",
    "        plot_loss(loss_log)\n",
    "        torch.save(ca.state_dict(), model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py37_torch] *",
   "language": "python",
   "name": "conda-env-py37_torch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
